-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A big improvement for dtype casting system with fp8 storage type and manual cast #14031
Conversation
Since norm layer need fp32, I only convert the linear operation layer(conv2d/linear) And TE have some pytorch function not support bf16 amp in CPU. I add a condition to indicate if the autocast is for unet.
great work |
Could the Lora issue be solved by merging the lora weights before converting the model to FP8 for inference? It would require the model to be reloaded every time a lora is changed but if the fp16 version is cached in ram I think it could still be fast enough to be worth it. |
Yes it could be solved |
Sorry I'm wrong, it is quite hard to be solved since lora ext actually don't know when it should load the fp16 weights. I have an idea is to cache the fp16 weight in fp8-layers directly (in CPU) but it will require more sys ram. Don't know if you think it is ok. (and definitely, it will be an option) |
@BetaDoggo I have added a mechanism to cache/restore fp16 weight when needed. This will require 5G more system ram to achive it on SDXL. But it do give us closer result with fp8 when we using lora: |
M2 Pro 16gb RAM Same prompt + sampler + seed + plugins (too tired to describe everything). Main branch: Using 16gb + 10gb swap Bravo! |
Just wanted to give some feedback that I've been using your branch for a couple of weeks now and have had no problems at all (CUDA 4080). I'm finding minimal quality loss versus BF16 and generation speeds pretty much on par with LCM/Turbo models and LORAs at 16bit. |
Hi @KohakuBlueleaf, Is fp8 going to be in v1.7.0? I also tried the fp8 settings option, and I get a non-reproducible image during the first switch from fp16 to fp8, when applying changes in settings: |
|
vid.mp4when I have fp8 enabled in settings, then exit the program. Then open the program, and the random seed distribution is entirely different. GPU seed is affected, not CPU seed. |
This is quite interesting, will check it |
Want to check if I understand this correctly: But if the "fp8" is enabled "after" startup (startup with fp16), it will be normal. |
@ClashSAN I have investigated some similar effect but different. |
cond caching. |
I confirmed your conjesture by breakpointing at the UNetModel's forward and do a switch of fp16->fp8->fp16. The result shows that context cache is permanently changed after enabling fp8. So the problem can be effectively fixed by invalidating the cond cache at the switching time. Here is the crucial evidence: |
I keep getting this error: How can I fix this? Sorry if this is the wrong place to ask, I'm somewhat new to github. Open to try any troubleshooting steps you think may help! Thanks for all your work on this! Looking forward to playing around with this. |
@Manchovies You probably have an older version of torch. The fp8 options were added in torch 2.1. If you haven't reinstalled recently you likely have 2.0 or below. I believe if you delete or rename your venv folder the webui will install 2.1 automatically. |
That worked! Thank you so much. Now, how can I get xformers installed? I remember trying in the past, and it would always uninstall the version of torch I have installed and install an older version and xformers to go with it. Maybe put "xformers==0.0.22.post7" in the requirements.txt file, or pip install xformers==0.0.22.post7 in the venv folder? Edit: that seems to have done it. went to /venv/scripts, ran "activate" in cmd, and typed "pip install xformers==0.0.22.post7" and it installed correctly without uninstalling or tampering with the torch install I had set up. Thanks again! Happy to be playing around with FP8 and enjoying the VRAM savings :) |
Hm, I thought I could run SDXL completely in VRAM with this (6 GB) but it needs --medvram for it to not OOM. Comparing it to Comfy without FP8, Comfy takes around 12 seconds, WebUi with FP8 and medvram around 9 seconds, so its a decent improvement. But idk why it won't fit completely into VRAM. |
For 6GB or worse card, you can wait for the next big update for lowvram, which will be as fast as comfy/forge for super lowvram cards. At least the author have tested it on 3060 6G |
Description
After pytorch 2.1.0, pytorch added 2 new dtype as storage type: float8_e5m2, float8_e4m3fn.[1][2]
Based on the papers which discuss the usage of fp8 as parameter/gradient for training/using NN models. I think it is worth doing some optimization with fp8 format.[3][4]
Also, some extension already support this feature too [5]
Mechanism
Although pytorch2.1.0 start supporting fp8 as storage type. We have only few hidden method for H100 to computing matmul with fp8 dtype.[6] Which means even though we can store model weights in FP8, we still need to use fp16 or bf16 to compute the result. (a.k.a upcasting)
Fortunately, pytorch's autocast can do it for us without any other changes. We just need to avoid some modules which not support fp8 storage, for example:
nn.Embedding
.And for doing this for some devices which not support autocast, I also implement a manualcast hook which support GTX16xx(or even older) series to utilize fp8 features.
Manual Cast
The idea is pretty simple, when parameters and inputs have different dtype then target dtype (defined in devices.py), cast it to target dtype.
I hooked the
Linear
,Conv
,MultiHeadAttention
,GroupNorm
,LayerNorm
layers. Which could almost covered all the layers we need. But since the hook is very simple, we can just add the layers that need to be patched in the future.This feature also allow GTX16xx series or older to utilize fp16 (even fp8) weight to reduce vram usage significantly.
List of implemented things
Performance on VRAM usage/Speed/Quality
Vram usage
Theoretically, FP8 can reduce 50% of "vram usage introduced by weights" (FP8 vs FP16). So for sd1.x with --medvram, it can save about 800MB vram when inference, or 2.5G vram for sdxl.
Here is some measurement with HWinfo64 on Windows 11 23H2/RTX 4090
Launch args:
--medvram --opt-channelslast --xformers
Table 1. Vram usage comparison between FP16/FP8
First spike is FP16 run and Second spike is FP8
Upper bound: FP16 vram usage
Lower bound: idle vram usage
Cursor box: FP8 vram usage
We can find that FP8 save 2.5 ~ 2.8G vram in SDXL and 0.8 ~ 0.9G vram in SD1.x.
Which is almost match the theory.
Speed
Since use FP8 with FP16 computation need some extra operation to cast the dtype.
It will decrease the speed (especially for smaller batch size)
Table 2. Inference speed comparison between FP16/FP8
We can find that although FP8 is slower than FP16 but the speed difference is smaller than 5%, which is acceptable.
Quality
Surprisingly, FP8 barely decrease the quality of generated image, sometime it may even improve the quality. But normally the differences are subtle.
Some comparisons here:
Image 1. Image comparison between FP16/FP8 on SDXL
Image 2. Image comparison between FP16/FP8 on SD1.x
But interestingly(or, actually expceted), use LoRA/LyCORIS with FP8 enabled will "weaken" the effect of the lora, which means you may need higher (1.5x ~ 3x) weight to reach same effect.
For example here is the comparison on LCM-lora + SD1.x models. with Euler A/cfg 2/steps8:
Image 3. LoRA effect on FP8
some reported improvement
Information from animatediff extension, use FP8 on UNet+CN+AN can achive 1024x1024 i2i on 4090. Which is impossible in the past.
Also reported from users of animatediff, use fp8+lcm lora can improve the quality of result.
Conclusion
FP8 is good and almost zero cost improvement on VRAM usage. Which is a good news for SDXL users.
Moreover, we can even use fp8 to storage the model weight directly which can reduce the ram requirement as well. (fp8 ckpt + fp8 storage enabled can achive sys ram 8G + vram 4G requirement with --medvram)
But the "weaken effect" on LoRA also reveal some possible problem after apply this method. We may need more feedback on this feature.But Good news for this problem: just disable it can avoid every affection.
And this feature also required pytorch2.1.0 which "may" be unstable. (since pytorch 2.1.1 already be stable, maybe we can consider to wait for xformers' update for pytorch 2.1.1 and then merge this PR)
Reference
Appandix
Training with FP8 also be implemented based on kohya-ss/sd-scripts codebase in my fork, I also provide some example models for it. With my implementation, users can train sdxl lora/lycoris on 6G vram card with TE/latent been cached, or train them on 8G vram card with nothing been cached with 1024x1024 arb.[7]
Checklist: